import numpy as np
import argparse
import joblib
from utils import uni_normal, create_oracle, realize

class Environment:
    """Minimal base interface for an interactive environment.

    Every keyword argument handed to the constructor is stored verbatim as an
    instance attribute. The hook methods below do nothing and return ``None``;
    subclasses override whichever ones they need.
    """

    def __init__(self, **kwargs):
        # Copy the configuration straight into the instance namespace.
        vars(self).update(kwargs)

    def new(self, t):
        """Prepare round ``t``. No-op in the base class."""

    def get_features(self):
        """Return the features for the current round. No-op in the base class."""

    def recommend(self, *args):
        """Produce a recommendation. No-op in the base class."""

    def play(self, *args):
        """Execute a recommendation and report feedback. No-op in the base class."""

class MovieLensEnv(Environment):
    """Environment serving pre-computed MovieLens feature matrices.

    Keyword arguments are stored as attributes by ``Environment.__init__``.
    Read here: ``seed`` (int used to derive the RandomState seeds),
    ``oracle_type`` and ``K`` (forwarded to ``create_oracle`` / ``realize``).
    Optional: ``theta_path`` and ``features_path``, the joblib files holding
    the parameter vector and the feature collection; they default to
    ``'logistic_theta.pkl'`` and ``'logistic_features.pkl'`` (resolved
    relative to the working directory).
    """

    def __init__(self, **kwargs):

        super().__init__(**kwargs)

        # One seed fans out into separate RandomState streams so each source
        # of randomness is reproducible independently of the others.
        self.X_random = np.random.RandomState(seed=self.seed)  # round sampling
        # theta_random / target_random / user_random are not read within this
        # class; they are kept as attributes for any external users.
        self.theta_random = np.random.RandomState(seed=self.seed + 1)
        self.mnl_random = np.random.RandomState(seed=self.seed + 2)  # given to realize()
        self.target_random = np.random.RandomState(seed=self.seed + 3)
        self.user_random = np.random.RandomState(seed=self.seed + 4)

        # File locations are overridable via kwargs; the defaults are the
        # names this class always used. NOTE: joblib.load unpickles, so only
        # point these at trusted files.
        theta_path = getattr(self, 'theta_path', 'logistic_theta.pkl')
        features_path = getattr(self, 'features_path', 'logistic_features.pkl')

        self.theta = joblib.load(theta_path)

        # callable: per-arm utility vector -> selected row indices into X
        self.oracle = create_oracle(self.oracle_type, self.K)
        # callable: utility vector -> (ereward, reward, Y, stop), see play()
        self.realize_cascade = realize(self.K, self.mnl_random)

        self.arm_features = joblib.load(features_path)

        # presumably one feature matrix per user (see _reset_arms) — TODO confirm
        self.num_users = len(self.arm_features)

    def new(self, t):
        """Begin round ``t`` by sampling a fresh feature matrix into ``self.X``."""
        self._reset_arms(t)

    def get_features(self):
        """Return the current round's arm features (set by ``new``)."""
        return self.X

    def recommend(self, utility):
        """Return the oracle's selection for the given per-arm ``utility``."""
        return self.oracle(utility)

    def play(self, t, ACT):
        """Simulate feedback for the played list ``ACT`` and compute regret.

        Args:
            t: Round index (passed through to ``_get_OPT``).
            ACT: Row indices into ``self.X`` chosen by the learner, e.g. the
                output of ``recommend``.

        Returns:
            ``(Y, stop, regret)``: ``Y`` and ``stop`` are the third and fourth
            outputs of ``realize_cascade`` for ``ACT``; ``regret`` is the
            difference of the first outputs (``ereward``) between the oracle
            list and ``ACT``.
        """
        OPT = self._get_OPT(t)

        OPT_utility = self.X[OPT].dot(self.theta)  # (K,)
        ACT_utility = self.X[ACT].dot(self.theta)  # (K,)

        # Keep the OPT call, and keep it first: realize_cascade was built
        # with mnl_random and presumably draws from it, so dropping or
        # reordering this call would change the random stream.
        OPT_ereward, _, _, _ = self.realize_cascade(OPT_utility)
        ACT_ereward, _, Y, stop = self.realize_cascade(ACT_utility)

        regret = OPT_ereward - ACT_ereward

        return Y, stop, regret

    def _reset_arms(self, t):
        """Pick one entry of ``arm_features`` uniformly at random as ``self.X``."""
        # choice(n) draws the same index as choice(range(n)) from the same
        # RNG state, but without materialising an n-element array each round.
        _i = self.X_random.choice(self.num_users)

        self.X = self.arm_features[_i]
        self._i = _i

    def _get_OPT(self, t):
        """Return the oracle's selection under the true parameter ``theta``."""
        OPT = self.oracle(self.X.dot(self.theta))
        return OPT


if __name__ == "__main__":
    # Running this module directly performs no work.
    ...